package com.young.common.core.dal.page;

import com.young.common.util.Page;
import com.young.common.util.StringUtils;
import com.young.common.core.dal.DalRowBounds;
import com.young.common.util.ReflectUtil;
import org.apache.ibatis.executor.parameter.ParameterHandler;
import org.apache.ibatis.executor.statement.RoutingStatementHandler;
import org.apache.ibatis.executor.statement.StatementHandler;
import org.apache.ibatis.mapping.BoundSql;
import org.apache.ibatis.mapping.MappedStatement;
import org.apache.ibatis.mapping.ParameterMapping;
import org.apache.ibatis.plugin.*;
import org.apache.ibatis.reflection.MetaObject;
import org.apache.ibatis.reflection.SystemMetaObject;
import org.apache.ibatis.reflection.factory.DefaultObjectFactory;
import org.apache.ibatis.reflection.factory.ObjectFactory;
import org.apache.ibatis.reflection.wrapper.DefaultObjectWrapperFactory;
import org.apache.ibatis.reflection.wrapper.ObjectWrapperFactory;
import org.apache.ibatis.scripting.defaults.DefaultParameterHandler;
import org.apache.ibatis.session.RowBounds;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.sql.Connection;
import java.sql.PreparedStatement;
import java.sql.ResultSet;
import java.sql.SQLException;
import java.util.List;
import java.util.Properties;

/**
 * MyBatis pagination interceptor: intercepts statements that must be paged and rewrites them.
 *
 * <p>How it works: to talk to the database through JDBC MyBatis has to create a
 * {@code Statement}, and the SQL text exists before that statement is created. The statement is
 * produced by {@code StatementHandler.prepare(Connection)} (invoked through
 * {@code RoutingStatementHandler}), so this plugin intercepts that method, swaps the SQL held by
 * the {@link BoundSql} for the paged SQL, and then lets the call continue with
 * {@code invocation.proceed()}.
 *
 * <p>Paging also needs the total number of matching rows. For that the original SQL is turned
 * into a count statement, MyBatis' own parameter handling binds the parameters, and the count
 * statement is executed on the intercepted connection.
 *
 * <p>A query is paged only when its {@link RowBounds} is a {@link DalRowBounds} that carries a
 * non-null {@link Page}.
 */
@Intercepts({@Signature(method = "prepare", type = StatementHandler.class, args = {Connection.class}) })
public class PageInterceptorPlugin implements Interceptor {

    private static final Logger logger = LoggerFactory.getLogger(PageInterceptorPlugin.class);
    private static final ObjectFactory DEFAULT_OBJECT_FACTORY = new DefaultObjectFactory();
    private static final ObjectWrapperFactory DEFAULT_OBJECT_WRAPPER_FACTORY = new DefaultObjectWrapperFactory();

    /** Builds the paging and count SQL; each database dialect has its own builder. */
    private IPageSqlBuilder pageSqlBuilder = null;

    /**
     * Rewrites the SQL for paged queries, then continues with the intercepted call.
     *
     * @param invocation the intercepted {@code StatementHandler.prepare} call
     * @return whatever the intercepted method returns
     * @throws Throwable anything thrown by the intercepted method
     */
    @Override
    public Object intercept(Invocation invocation) throws Throwable {
        this.before(invocation);
        return invocation.proceed();
    }

    /**
     * Wraps the target object with this interceptor.
     *
     * @param target the object MyBatis offers for interception
     * @return the result of {@link Plugin#wrap(Object, Interceptor)}
     */
    @Override
    public Object plugin(Object target) {
        return Plugin.wrap(target, this);
    }

    /**
     * Receives the properties configured for this plugin (e.g. in mybatis-config).
     *
     * <p>A non-blank {@code db} property selects the SQL builder through
     * {@code PageSqlBuilderFactory.getBuilder}; otherwise the factory's default builder is used.
     *
     * @param properties the configured plugin properties, may be {@code null}
     */
    @Override
    public void setProperties(Properties properties) {
        if (properties != null && StringUtils.isNotBlank(properties.getProperty("db"))) {
            this.pageSqlBuilder = PageSqlBuilderFactory.getBuilder(properties.getProperty("db"));
        } else {
            this.pageSqlBuilder = PageSqlBuilderFactory.getDefaultBuilder();
        }
    }

    /**
     * Returns the configured SQL builder, falling back to the factory default when none is set.
     *
     * <p>The field is only assigned in {@link #setProperties(Properties)}; without this fallback
     * a plugin instance on which that method was never called would fail with a
     * {@code NullPointerException} on the first paged query.
     */
    private IPageSqlBuilder resolvePageSqlBuilder() {
        if (this.pageSqlBuilder == null) {
            logger.warn("[mybatis page plugin] no page sql builder configured, using the default builder");
            this.pageSqlBuilder = PageSqlBuilderFactory.getDefaultBuilder();
        }
        return this.pageSqlBuilder;
    }

    /**
     * Replaces the statement's SQL with the paged SQL when the query was issued with a
     * {@link DalRowBounds} carrying a {@link Page}; every other statement is left untouched.
     *
     * @param invocation the intercepted {@code StatementHandler.prepare} call; its first argument
     *                   is the JDBC connection
     */
    private void before(Invocation invocation) {
        StatementHandler statementHandler = (StatementHandler) invocation.getTarget();

        MetaObject metaStatementHandler = MetaObject.forObject(
                statementHandler, DEFAULT_OBJECT_FACTORY, DEFAULT_OBJECT_WRAPPER_FACTORY);
        RowBounds rowBounds = (RowBounds) metaStatementHandler.getValue("delegate.rowBounds");

        // Only a DalRowBounds with a Page marks a paged query; return before doing any further
        // reflective lookups for everything else.
        if (!(rowBounds instanceof DalRowBounds)) {
            return;
        }
        Page<?> page = ((DalRowBounds) rowBounds).getPage();
        if (page == null) {
            return;
        }

        StatementHandler delegate = (StatementHandler) ReflectUtil
                .getFieldValue(statementHandler, "delegate");
        BoundSql boundSql = delegate.getBoundSql();
        String sql = boundSql.getSql();

        // A total that is already set is not recomputed, so callers with special counting needs
        // can compute the total themselves and assign it to the page beforehand.
        if (page.getTotal() == 0) {
            MappedStatement mappedStatement = (MappedStatement) ReflectUtil
                    .getFieldValue(delegate, "mappedStatement");
            Connection connection = (Connection) invocation.getArgs()[0];
            this.setTotalRecord(page, boundSql.getParameterObject(), mappedStatement, connection);
        }

        String pageSql = this.getPageSql(page, sql);
        // BoundSql exposes no setter for its SQL, so the field is overwritten reflectively.
        ReflectUtil.setFieldValue(boundSql, "sql", pageSql);
    }

    /**
     * Builds the paged SQL for the given page by delegating to the dialect-specific builder.
     *
     * @param page supplies the page number and page size
     * @param sql  the original SQL
     * @return the SQL returned by the builder's {@code getPageSql}
     */
    private String getPageSql(Page<?> page, String sql) {
        return resolvePageSqlBuilder().getPageSql(sql, page.getPageNumber(), page.getPageSize());
    }

    /**
     * Runs the count statement for the query and stores the result as the page total.
     *
     * <p>A {@link SQLException} is logged and not rethrown; in that case the page total is left
     * unchanged. The connection is not closed here.
     *
     * @param page            the page whose total is set
     * @param parameter       the parameter object of the intercepted statement
     * @param mappedStatement the mapped statement being executed
     * @param connection      the connection passed to the intercepted {@code prepare} call
     */
    private void setTotalRecord(Page<?> page, Object parameter, MappedStatement mappedStatement,
                                Connection connection) {
        BoundSql boundSql = mappedStatement.getBoundSql(parameter);
        String countSql = resolvePageSqlBuilder().getCountSql(boundSql.getSql());

        List<ParameterMapping> parameterMappings = boundSql.getParameterMappings();
        BoundSql countBoundSql = new BoundSql(mappedStatement.getConfiguration(), countSql,
                parameterMappings, parameter);

        // Hand the original BoundSql's metaParameters to the count BoundSql so that the
        // parameter handler below can see the same additional parameters.
        MetaObject countBsObject = SystemMetaObject.forObject(countBoundSql);
        MetaObject boundSqlObject = SystemMetaObject.forObject(boundSql);
        countBsObject.setValue("metaParameters", boundSqlObject.getValue("metaParameters"));

        ParameterHandler parameterHandler = new DefaultParameterHandler(mappedStatement, parameter, countBoundSql);

        PreparedStatement pstmt = null;
        ResultSet rs = null;
        try {
            pstmt = connection.prepareStatement(countSql);
            parameterHandler.setParameters(pstmt);
            rs = pstmt.executeQuery();
            if (rs.next()) {
                int totalRecord = rs.getInt(1);
                // Store the total on the page object supplied by the caller.
                page.setTotal(totalRecord);
            }
        } catch (SQLException e) {
            logger.error("[mybatis分页插件] 设置总记录数阶段发生异常", e);
        } finally {
            // Close each resource on its own so a failing ResultSet.close() cannot leak the statement.
            closeResultSet(rs);
            closeStatement(pstmt);
        }
    }

    /** Closes the result set if present, logging instead of propagating a close failure. */
    private void closeResultSet(ResultSet rs) {
        if (rs == null) {
            return;
        }
        try {
            rs.close();
        } catch (SQLException e) {
            logger.error("[mybatis分页插件] 关闭jdbc发生异常", e);
        }
    }

    /** Closes the statement if present, logging instead of propagating a close failure. */
    private void closeStatement(PreparedStatement pstmt) {
        if (pstmt == null) {
            return;
        }
        try {
            pstmt.close();
        } catch (SQLException e) {
            logger.error("[mybatis分页插件] 关闭jdbc发生异常", e);
        }
    }

}
